/-
Copyright (c) 2025 Rémy Degenne. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Rémy Degenne, Lorenzo Luccioli
-/

import Mathlib.Probability.Kernel.Composition.Comp

/-!
# Risk of an estimator

An estimation problem is defined by a parameter space `Θ`, a data generating kernel `P : Kernel Θ 𝓧`
and a loss function `ℓ : Θ → 𝓨 → ℝ≥0∞`.
A (randomized) estimator is a kernel `κ : Kernel 𝓧 𝓨` that maps data to estimates of a quantity
of interest that depends on the parameter. Often the quantity of interest is the parameter itself
and `𝓨 = Θ`.
The quality of an estimate `y` when data comes from the distribution with parameter `θ` is measured
by the value of the loss function `ℓ θ y` (lower is better).

## Main definitions

The risk is the average loss of the estimator `κ` on data generated by `P` with parameter `θ`,
equal to `∫⁻ y, ℓ θ y ∂((κ ∘ₖ P) θ)`. We do not introduce a definition for that risk.

* `avgRisk ℓ P κ π`: the average of the risk of the estimator with respect to
  the prior `π : Measure Θ`.
* `bayesRisk ℓ P π`: the Bayes risk with respect to the prior `π`, infimum of the average
  risks over all estimators, that is over all Markov kernels `κ : Kernel 𝓧 𝓨`.
* `minimaxRisk ℓ P`: minimax risk, infimum over all estimators (Markov kernels) of the supremum
  over `θ` of the risk.

-/

open MeasureTheory
open scoped ENNReal

namespace ProbabilityTheory

variable {Θ 𝓧 𝓨 : Type*} {mΘ : MeasurableSpace Θ} {m𝓧 : MeasurableSpace 𝓧}

/-- The average risk of an estimator `κ` on an estimation task with loss `ℓ` and
data generating kernel `P` with respect to a prior `π`, that is
`∫⁻ θ, ∫⁻ y, ℓ θ y ∂((κ ∘ₖ P) θ) ∂π`: the risk of `κ` at `θ`, integrated against `π`.
No Markov (or finiteness) assumption is made on `κ`, `P` or `π`. -/
-- `m𝓨` is an implicit (not instance) argument: it is inferred from the type of `κ`.
noncomputable
def avgRisk {m𝓨 : MeasurableSpace 𝓨}
    (ℓ : Θ → 𝓨 → ℝ≥0∞) (P : Kernel Θ 𝓧) (κ : Kernel 𝓧 𝓨) (π : Measure Θ) : ℝ≥0∞ :=
  -- inner integral: risk of `κ` at parameter `θ`; outer integral: average over the prior `π`
  ∫⁻ θ, ∫⁻ y, ℓ θ y ∂((κ ∘ₖ P) θ) ∂π

/-- The Bayes risk with respect to a prior `π`, defined as the infimum of the average risks
`avgRisk ℓ P κ π` over all estimators `κ : Kernel 𝓧 𝓨` that are Markov kernels.
If there is no Markov kernel from `𝓧` to `𝓨`, the infimum is over an empty family and
equals `∞`. -/
-- The measurable space on `𝓨` is instance-implicit here: none of the explicit arguments
-- (`ℓ`, `P`, `π`) mentions it, so it cannot be inferred by unification.
noncomputable
def bayesRisk [MeasurableSpace 𝓨] (ℓ : Θ → 𝓨 → ℝ≥0∞) (P : Kernel Θ 𝓧) (π : Measure Θ) : ℝ≥0∞ :=
  ⨅ (κ : Kernel 𝓧 𝓨) (_ : IsMarkovKernel κ), avgRisk ℓ P κ π

/-- The minimax risk, defined as the infimum over estimators `κ : Kernel 𝓧 𝓨` that are Markov
kernels of the worst-case risk `⨆ θ, ∫⁻ y, ℓ θ y ∂((κ ∘ₖ P) θ)` (a supremum over `θ`, which
need not be attained). -/
-- As for `bayesRisk`, the measurable space on `𝓨` is instance-implicit because none of the
-- explicit arguments (`ℓ`, `P`) mentions it.
noncomputable
def minimaxRisk [MeasurableSpace 𝓨] (ℓ : Θ → 𝓨 → ℝ≥0∞) (P : Kernel Θ 𝓧) : ℝ≥0∞ :=
  ⨅ (κ : Kernel 𝓧 𝓨) (_ : IsMarkovKernel κ), ⨆ θ, ∫⁻ y, ℓ θ y ∂((κ ∘ₖ P) θ)

-- Implicit variables for the lemmas below; some lemmas rebind `ℓ`, `P`, `κ` or `π` explicitly.
variable {m𝓨 : MeasurableSpace 𝓨}
  {ℓ : Θ → 𝓨 → ℝ≥0∞} {P : Kernel Θ 𝓧} {κ : Kernel 𝓧 𝓨} {π : Measure Θ}

section Zero

/-- With the zero data generating kernel `P = 0`, every estimator has average risk `0`. -/
@[simp]
lemma avgRisk_zero_left (ℓ : Θ → 𝓨 → ℝ≥0∞) (κ : Kernel 𝓧 𝓨) (π : Measure Θ) :
    avgRisk ℓ (0 : Kernel Θ 𝓧) κ π = 0 := by simp [avgRisk]

/-- The zero estimator `κ = 0` has average risk `0`. -/
@[simp]
lemma avgRisk_zero_right (ℓ : Θ → 𝓨 → ℝ≥0∞) (P : Kernel Θ 𝓧) (π : Measure Θ) :
    avgRisk ℓ P (0 : Kernel 𝓧 𝓨) π = 0 := by simp [avgRisk]

/-- The average risk with respect to the zero prior `π = 0` is `0`. -/
@[simp]
lemma avgRisk_zero_prior (ℓ : Θ → 𝓨 → ℝ≥0∞) (P : Kernel Θ 𝓧) (κ : Kernel 𝓧 𝓨) :
    avgRisk ℓ P κ 0 = 0 := by simp [avgRisk]

/-- With the zero data generating kernel `P = 0`, the Bayes risk is `0`.
The `Nonempty 𝓨` hypothesis excludes the situation of `bayesRisk_of_isEmpty'`, where no Markov
kernel exists and the infimum is `∞`. -/
@[simp]
lemma bayesRisk_zero_left [Nonempty 𝓨] (ℓ : Θ → 𝓨 → ℝ≥0∞) (π : Measure Θ) :
    bayesRisk ℓ (0 : Kernel Θ 𝓧) π = 0 := by simp [bayesRisk, iInf_subtype']

/-- The Bayes risk with respect to the zero prior `π = 0` is `0`.
Note: `_right` refers here to the prior, whereas in `avgRisk_zero_right` it refers to the
estimator. -/
@[simp]
lemma bayesRisk_zero_right [Nonempty 𝓨] (ℓ : Θ → 𝓨 → ℝ≥0∞) (P : Kernel Θ 𝓧) :
    bayesRisk ℓ P (0 : Measure Θ) = 0 := by simp [bayesRisk, iInf_subtype']

/-- With the zero data generating kernel `P = 0`, the minimax risk is `0`.
`Nonempty 𝓨` plays the same role as in `bayesRisk_zero_left`. -/
@[simp]
lemma minimaxRisk_zero [Nonempty 𝓨] (ℓ : Θ → 𝓨 → ℝ≥0∞) :
    minimaxRisk ℓ (0 : Kernel Θ 𝓧) = 0 := by simp [minimaxRisk, iInf_subtype']

end Zero

section Empty

/-- If `𝓧` is nonempty and `𝓨` is empty, there is no Markov kernel from `𝓧` to `𝓨`: the type
`Kernel 𝓧 𝓨` is a subsingleton, so every kernel equals `0`, and `0` is not a Markov kernel.
Shared by `bayesRisk_of_isEmpty'` and `minimaxRisk_of_isEmpty'`. -/
private lemma isEmpty_subtype_isMarkovKernel_of_isEmpty [Nonempty 𝓧] [IsEmpty 𝓨] :
    IsEmpty (Subtype (@IsMarkovKernel 𝓧 𝓨 m𝓧 m𝓨)) := by
  simp only [isEmpty_subtype]
  exact fun κ ↦ Subsingleton.elim κ 0 ▸ Kernel.not_isMarkovKernel_zero

/-- If `𝓧` is empty, the data generating kernel `P` is `0`, so the average risk is `0`. -/
@[simp]
lemma avgRisk_of_isEmpty [IsEmpty 𝓧] : avgRisk ℓ P κ π = 0 := by
  simp [Subsingleton.elim P 0]

/-- If `𝓨` is empty, the estimator `κ` is `0`, so the average risk is `0`. -/
@[simp]
lemma avgRisk_of_isEmpty' [IsEmpty 𝓨] : avgRisk ℓ P κ π = 0 := by
  simp [Subsingleton.elim κ 0]

/-- If the parameter space `Θ` is empty, the average risk is `0`. -/
@[simp]
lemma avgRisk_of_isEmpty'' [IsEmpty Θ] : avgRisk ℓ P κ π = 0 := by
  simp [avgRisk]

/-- If `𝓧` is empty, every estimator has average risk `0`, so the Bayes risk is `0`. -/
@[simp]
lemma bayesRisk_of_isEmpty [IsEmpty 𝓧] : bayesRisk ℓ P π = 0 := by
  simp [bayesRisk]

/-- If `𝓧` is nonempty and `𝓨` is empty, there is no Markov kernel from `𝓧` to `𝓨`, so the
infimum defining the Bayes risk is over an empty family and equals `∞`. -/
@[simp]
lemma bayesRisk_of_isEmpty' [Nonempty 𝓧] [IsEmpty 𝓨] :
    bayesRisk ℓ P π = ∞ := by
  -- local instance so that `simp` can see that the family of Markov kernels is empty
  have : IsEmpty (Subtype (@IsMarkovKernel 𝓧 𝓨 m𝓧 m𝓨)) :=
    isEmpty_subtype_isMarkovKernel_of_isEmpty
  simp [bayesRisk, iInf_subtype']

/-- If the parameter space `Θ` is empty (and `𝓨` is nonempty), the Bayes risk is `0`. -/
@[simp]
lemma bayesRisk_of_isEmpty'' [IsEmpty Θ] [Nonempty 𝓨] :
    bayesRisk ℓ P π = 0 := by
  simp [bayesRisk, iInf_subtype']

/-- If `𝓧` is empty, the data generating kernel `P` is `0`, so the minimax risk is `0`. -/
@[simp]
lemma minimaxRisk_of_isEmpty [IsEmpty 𝓧] : minimaxRisk ℓ P = 0 := by
  simp [minimaxRisk, Subsingleton.elim P 0]

/-- If `𝓧` is nonempty and `𝓨` is empty, there is no Markov kernel from `𝓧` to `𝓨`, so the
infimum defining the minimax risk is over an empty family and equals `∞`. -/
@[simp]
lemma minimaxRisk_of_isEmpty' [Nonempty 𝓧] [IsEmpty 𝓨] : minimaxRisk ℓ P = ∞ := by
  -- local instance so that `simp` can see that the family of Markov kernels is empty
  have : IsEmpty (Subtype (@IsMarkovKernel 𝓧 𝓨 m𝓧 m𝓨)) :=
    isEmpty_subtype_isMarkovKernel_of_isEmpty
  simp [minimaxRisk, iInf_subtype']

/-- If the parameter space `Θ` is empty (and `𝓨` is nonempty), the supremum over `θ` is `0`,
so the minimax risk is `0`. -/
@[simp]
lemma minimaxRisk_of_isEmpty'' [Nonempty 𝓨] [IsEmpty Θ] : minimaxRisk ℓ P = 0 := by
  simp [minimaxRisk, iInf_subtype']
end Empty

end ProbabilityTheory
